Classification with KNN and Logistic Regression
Classification with K-Nearest-Neighbors
KNN for Categorical Response
We have existing observations
\[(x_1, C_1), \ldots, (x_n, C_n)\]
where the \(C_i\) are categories.
Given a new observation \(x_{new}\), how do we predict \(C_{new}\)?
Predicting a Response with KNN
Find the \(K\) (e.g., 5) values in \((x_1, \ldots, x_n)\) that are closest to \(x_{new}\).
Let all the closest neighbors “vote” on the category.
Predict \(\widehat{C}_{new}\) = the category with the most votes.
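As a concrete illustration, here is a minimal sketch of the voting procedure in base R; the toy vectors train_x, train_c, and x_new are hypothetical, not from our data:

# Toy training data: one numeric predictor, one category per observation
train_x <- c(1.0, 1.2, 3.5, 3.7, 3.9)
train_c <- c("A", "A", "B", "B", "B")
x_new <- 3.0
k <- 3

# Find the indices of the k training points closest to x_new
nearest <- order(abs(train_x - x_new))[1:k]

# Let the neighbors vote, then predict the most common category
votes <- table(train_c[nearest])
names(votes)[which.max(votes)]   # "B"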
KNN for Classification
To perform classification with K-Nearest-Neighbors, we choose the K closest observations to our target, and we aggregate their response values.
Let’s keep hanging out with the insurance dataset.
Suppose we want to use information about insurance charges to predict whether someone is a smoker or not.
# A tibble: 431 × 6
age sex bmi smoker region charges
<dbl> <chr> <dbl> <chr> <chr> <dbl>
1 19 female 27.9 yes southwest 16885.
2 33 male 22.7 no northwest 21984.
3 32 male 28.9 no northwest 3867.
4 31 female 25.7 no southeast 3757.
5 60 female 25.8 no northwest 28923.
6 25 male 26.2 no northeast 2721.
7 62 female 26.3 yes southeast 27809.
8 56 female 39.8 no southeast 11091.
9 27 male 42.1 yes southeast 39612.
10 23 male 23.8 no northeast 2395.
# ℹ 421 more rows
Step 1: Establish the Model
knn_mod <- nearest_neighbor(neighbors = 5) %>%
  set_engine("kknn") %>%
  set_mode("classification")
Notice: the mode is now "classification".
Step 2: Fit our Model
knn_fit_1 <- knn_mod %>%
  fit(smoker ~ charges, data = ins)
Error in `check_outcome()`:
! For a classification model, the outcome should be a `factor`, not a `character`.
Step 3: (Re)Fit our Model
ins <- ins %>%
  mutate(smoker = factor(smoker))  # the outcome must be a factor, per the error above

knn_fit_1 <- knn_mod %>%
  fit(smoker ~ charges, data = ins)
Heck yeah!
knn_fit_1$fit %>% summary()
Call:
kknn::train.kknn(formula = smoker ~ charges, data = data, ks = min_rows(5, data, 5))
Type of response variable: nominal
Minimal misclassification: 0.06032
Best kernel: optimal
Best k: 5
Try it!
Open Activity-Classification.qmd.
Select the best KNN model for predicting smoker status.
What metrics does the cross-validation process automatically output?
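As a hint, here is one possible sketch of that cross-validation workflow. It assumes the tidymodels meta-package is loaded, smoker is already a factor, and 5-fold CV; the candidate values of K are arbitrary:

# Candidate model with K left as a tuning parameter
knn_tune <- nearest_neighbor(neighbors = tune()) %>%
  set_engine("kknn") %>%
  set_mode("classification")

knn_wflow <- workflow() %>%
  add_formula(smoker ~ charges) %>%
  add_model(knn_tune)

ins_cv <- vfold_cv(ins, v = 5)

# Classification resamples report accuracy and ROC AUC by default
knn_results <- knn_wflow %>%
  tune_grid(resamples = ins_cv,
            grid = tibble(neighbors = c(3, 5, 11, 21)))

collect_metrics(knn_results)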
Ordinary linear regression classification?
lm_mod <- linear_reg() %>%
  set_engine("lm") %>%
  set_mode("classification")
Error in `set_mode()`:
! 'classification' is not a known mode for model `linear_reg()`.
Ordinary linear regression with a dummy variable
Consider the following idea:
Convert the smoker variable to a dummy variable:
ins <- ins %>%
  mutate(
    smoker_number = case_when(
      smoker == "yes" ~ 1,
      smoker == "no" ~ 0
    )
  )
Ordinary linear regression with a dummy variable
Fit a linear regression predicting the smoker dummy variable:
lm_mod <- linear_reg() %>%
  set_engine("lm") %>%
  set_mode("regression")

ins_rec <- recipe(smoker_number ~ charges, data = ins) %>%
  step_normalize(all_numeric_predictors())

ins_wflow <- workflow() %>%
  add_recipe(ins_rec) %>%
  add_model(lm_mod)

ins_fit <- ins_wflow %>%
  fit(data = ins)
Ordinary linear regression with a dummy variable
Classify each observation by rounding its predicted number to the nearer category (0 = non-smoker, 1 = smoker):
preds <- ins_fit %>% predict(ins)

ins <- ins %>%
  mutate(
    predicted_num = preds$.pred,
    predicted_smoker = round(predicted_num)
  )
Ordinary linear regression
How did we do?
ins %>%
  count(predicted_smoker, smoker_number)
# A tibble: 4 × 3
predicted_smoker smoker_number n
<dbl> <dbl> <int>
1 0 0 336
2 0 1 28
3 1 0 8
4 1 1 59
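From this table, the raw accuracy is the proportion on the diagonal: (336 + 59) / 431, or about 0.916. Computed directly:

ins %>%
  mutate(correct = (predicted_smoker == smoker_number)) %>%
  summarize(accuracy = mean(correct))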
Residuals
Linear regression assumes that the residuals are Normally distributed. Obviously, they are not here: with a 0/1 response, the residuals at any given fitted value can only take two possible values.
Logistic Regression
Solution: what if we keep the same approach, with Y a function of X plus noise, but let the noise be non-Normal?
\[Y = g^{-1}(\beta_0 + \beta_1 X + \epsilon)\]
for some function \(g\).
Logistic Regression
Easier way to think of it:
Before:
\[\mu_Y = \beta_0 + \beta_1 X\]
Now:
\[g(\mu_Y) = \beta_0 + \beta_1 X\]
\(g\) is called the link function
Logistic Regression
A common link function is the logit function:
\[g(u) = \log\left(\frac{u}{1-u}\right)\]
In this case, \(u\) represents the probability of someone being a smoker.
Our observations \(Y\) have probability 0 or 1, since we observe them.
Future observations are unknown, so we predict them.
Logistic Regression
In summary:
Given predictors, we try to predict the log-odds of a person being a smoker.
We assume random noise on the relationship between the predictors and the log-odds of the response.
From these log-odds, we calculate the probabilities.
We compare the probabilities (between 0 and 1) to the observed truths (0 or 1 exactly).
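That last conversion is just the inverse of the logit: \(u = e^x / (1 + e^x)\). A quick sketch with hypothetical log-odds values:

log_odds <- c(-2, 0, 2)   # hypothetical log-odds values

# Inverse logit by hand: p = e^x / (1 + e^x)
exp(log_odds) / (1 + exp(log_odds))   # 0.119 0.500 0.881

# Equivalently, the built-in logistic CDF
plogis(log_odds)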
Step 1: Establish the Model
New model:
logit_mod <- logistic_reg() %>%
  set_mode("classification") %>%
  set_engine("glm")
Step 2: Make a Recipe
Same recipe, but now sticking with the original (untransformed) smoker variable:
ins_rec <- recipe(smoker ~ charges, data = ins) %>%
  step_normalize(all_numeric_predictors())
Step 3: Fit our Model
New workflow:
ins_wflow_logit <- workflow() %>%
  add_recipe(ins_rec) %>%
  add_model(logit_mod)

ins_fit <- ins_wflow_logit %>%
  fit(data = ins)

ins_fit %>% extract_fit_parsnip()   # pull_workflow_fit() is now deprecated
parsnip model object
Call: stats::glm(formula = ..y ~ ., family = stats::binomial, data = data)
Coefficients:
(Intercept) charges
-2.62 3.62
Degrees of Freedom: 430 Total (i.e. Null); 429 Residual
Null Deviance: 434
Residual Deviance: 139 AIC: 143
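Because the recipe normalized the predictor, the charges coefficient is per one standard deviation of charges; exponentiating it gives the multiplicative effect on the odds (a quick check, not part of the original output):

exp(3.62)   # each 1-SD increase in charges multiplies the odds of smoking by about 37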
Step 4: Get our Predictions
Notice: now our predictions are of type .pred_class! R did the hard part for us.
preds <- predict(ins_fit, new_data = ins)
preds
# A tibble: 431 × 1
.pred_class
<fct>
1 no
2 yes
3 no
4 no
5 yes
6 no
7 yes
8 no
9 yes
10 no
# ℹ 421 more rows
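Predicting for a genuinely new observation works the same way, since the workflow applies the recipe automatically; here with a hypothetical charge amount:

new_person <- tibble(charges = 30000)   # hypothetical new observation

predict(ins_fit, new_data = new_person)
predict(ins_fit, new_data = new_person, type = "prob")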
Log-Odds Predictions
Suppose we wanted to see the predicted log-odds values:
predict(ins_fit, new_data = ins, type = "raw")
1 2 3 4 5 6 7 8
-1.225256 0.328346 -5.191275 -5.224858 2.442245 -5.540267 2.102734 -2.990489
# ℹ 423 more fitted log-odds values
Probability Predictions
Suppose we wanted to see the predicted probabilities:
predict(ins_fit, new_data = ins, type = "prob")
# A tibble: 431 × 2
.pred_no .pred_yes
<dbl> <dbl>
1 0.773 0.227
2 0.419 0.581
3 0.994 0.00553
4 0.995 0.00535
5 0.0800 0.920
6 0.996 0.00391
7 0.109 0.891
8 0.952 0.0479
9 0.00334 0.997
10 0.996 0.00354
# ℹ 421 more rows
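These probabilities are the inverse logit of the raw log-odds above; for instance, the first observation (log-odds about -1.225):

plogis(-1.225)   # about 0.227, matching .pred_yes for row 1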
Plotting our Logistic Regression
pred_probs <- predict(ins_fit, new_data = ins, type = "prob")

ins %>%
  mutate(
    pred_probs = pred_probs$.pred_yes
  ) %>%
  ggplot(mapping = aes(y = pred_probs, x = charges, color = smoker)) +
  geom_point(alpha = 0.75) +
  scale_x_continuous(labels = label_dollar()) +
  labs(x = "Charges",
       y = "",
       title = "Predicted Probability of Being a Smoker based on Insurance Charges",
       color = "Smoking Status")
Plotting our Logistic Regression
[Plot: predicted probability of being a smoker vs. insurance charges, colored by smoking status]
Logistic Regression
How many did we get correct?
preds <- ins_fit %>% predict(ins)

ins <- ins %>%
  mutate(
    predicted_smoker = preds$.pred_class
  )

ins %>% count(predicted_smoker, smoker)
# A tibble: 4 × 3
predicted_smoker smoker n
<fct> <fct> <int>
1 no no 333
2 no yes 24
3 yes no 11
4 yes yes 63
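yardstick (loaded with tidymodels) can build this confusion matrix directly:

ins %>%
  conf_mat(truth = smoker, estimate = predicted_smoker)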
Logistic Regression
What percentage did we get correct?
ins %>%
  mutate(
    correct = (predicted_smoker == smoker)
  ) %>%
  count(correct) %>%
  mutate(
    pct = n / sum(n)
  )
# A tibble: 2 × 3
correct n pct
<lgl> <int> <dbl>
1 FALSE 35 0.0812
2 TRUE 396 0.919
Logistic Regression
What percentage did we get correct?
ins %>%
  accuracy(truth = smoker,
           estimate = predicted_smoker)
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 accuracy binary 0.919
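Accuracy is not the only option. For example, yardstick also provides sensitivity and specificity; here event_level = "second" treats "yes" (the second factor level) as the event of interest:

ins %>%
  sensitivity(truth = smoker,
              estimate = predicted_smoker,
              event_level = "second")

ins %>%
  specificity(truth = smoker,
              estimate = predicted_smoker,
              event_level = "second")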
What if we have a categorical variable where 99% of our values are Category A?
What if we have a categorical variable with more than 2 categories?
What if we want to use a link function other than the logit?
Are there other ways to do classification besides logistic regression and KNN?
Try it!
Open Activity-Classification.qmd again.
Select the best logistic regression model for predicting smoker status.
Report the cross-validated metrics - how do they compare to KNN?